import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sig
from util.feature import feature_extraction


# 计算肌电能量信号
def ca_ms_energy(signal, fs):
    """Return the instantaneous energy (sample-wise square) of an EMG signal.

    The earlier implementation walked over 32 ms frames with a 16 ms hop and
    wrote ``signal * signal`` into each frame slice, then handled the tail
    separately. Every output sample therefore ended up as the square of the
    matching input sample, so the framing only repeated work. It also divided
    by ``round(0.016 * fs)``, which is 0 for small ``fs`` and raised
    ``ZeroDivisionError``, and it could not accept plain Python lists.

    Args:
        signal: 1-D sequence of samples (list or ndarray).
        fs: Sampling rate. Kept so existing callers keep working; the result
            does not depend on it.

    Returns:
        A new float ndarray with the same length as ``signal`` holding
        ``signal[i] ** 2`` at every position.
    """
    # Converting to float first matches the float output buffer used before
    # and avoids integer overflow when squaring narrow integer dtypes.
    samples = np.asarray(signal, dtype=float)
    return np.square(samples)


# 滑动平均滤波
def smooth_signal(signal, fs=1000, window=0.032):
    """Rectify a signal and smooth it with a moving-average filter.

    Args:
        signal: 1-D sequence of samples (list or ndarray).
        fs: Sampling rate used to convert ``window`` into a sample count.
            Defaults to 1000, the value that used to be hard-coded.
        window: Length of the averaging window, in the same time unit that
            ``fs`` is expressed per (0.032 by default, i.e. 32 samples at the
            default ``fs``).

    Returns:
        ndarray with the moving average of ``abs(signal)``, computed with
        ``np.convolve(..., mode='same')``. In that mode NumPy returns
        ``max(len(signal), window_samples)`` values, so an input shorter than
        the window comes back with the window's length.
    """
    rectified = np.abs(signal)
    # Never let the window collapse to zero samples: an empty kernel would be
    # rejected by np.convolve (and 1/0 would be evaluated first).
    n = max(1, round(fs * window))
    kernel = np.ones((n,)) / n
    return np.convolve(rectified, kernel, mode='same')


# 归一化
def normalization(signal):
    """Min-max scale a signal into the range [0, 1].

    Args:
        signal: Sequence of samples (list or ndarray).

    Returns:
        Float ndarray ``(signal - min) / (max - min)``. A constant signal has
        no range to scale by, so it maps to all zeros instead of dividing by
        zero and producing NaNs. An empty input returns an empty array.
    """
    values = np.asarray(signal, dtype=float)
    if values.size == 0:
        return values.copy()
    y_min = values.min()
    span = values.max() - y_min
    if span == 0:
        return np.zeros_like(values)
    return (values - y_min) / span


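# Butterworth band-pass filter. NOTE(review): originally labelled 20~500 Hz, but the edges passed to sig.butter are 0.01 and 0.25 of the Nyquist frequency (5~125 Hz at fs = 1000 Hz) -- confirm the intended band.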
def filter_signal(signal):
    """Band-pass filter a signal with a Butterworth design (N=4), forward and backward.

    The input is flattened to 1-D with ``np.ravel`` and run through
    ``sig.filtfilt``, which applies the filter in both directions.

    NOTE(review): the band edges are handed to ``sig.butter`` without ``fs``,
    so they are fractions of the Nyquist frequency: 0.01 and 0.25. At a
    1000 Hz sampling rate that corresponds to 5-125 Hz, not the 20-500 Hz
    the accompanying comment mentions -- confirm which band is intended.

    Args:
        signal: Array-like of samples; any shape, flattened before filtering.

    Returns:
        1-D ndarray with the filtered samples.
    """
    low_edge = 20*0.5/1000
    high_edge = 500*0.5/1000
    numerator, denominator = sig.butter(4, [low_edge, high_edge], btype='bandpass')
    flattened = np.ravel(signal)
    return sig.filtfilt(numerator, denominator, flattened)


# 计算真实EMS
def calculate_ems(analog):
    """Convert a raw reading into the EMS value used by the rest of the module.

    The reading is shifted down by 512 and then divided by 1000.
    NOTE(review): 512 is presumably the mid-scale of a 10-bit converter and
    1000 a unit scaling -- confirm against the acquisition hardware.

    Args:
        analog: Scalar or ndarray of raw readings.

    Returns:
        ``(analog - 512) / 1000`` with the same shape as the input.
    """
    offset_removed = analog - 512
    return offset_removed / 1000


def enframe(signal, nw, inc):
    """Cut a 1-D signal into overlapping frames, zero-padding the tail.

    Args:
        signal: 1-D sequence of samples.
        nw: Frame length in samples.
        inc: Hop between the starts of consecutive frames, in samples.

    Returns:
        2-D ndarray with one frame per row and ``nw`` samples per row. The
        last frame is filled with zeros where the signal has run out.
    """
    total = len(signal)
    # A signal that fits inside one frame produces exactly one (padded) frame;
    # otherwise take as many hops as are needed to cover every sample.
    nf = 1 if total <= nw else int(np.ceil((1.0 * total - nw + inc) / inc))

    # Length of the signal once it is padded so the final frame is complete.
    padded_len = int((nf - 1) * inc + nw)
    padded = np.concatenate((signal, np.zeros((padded_len - total,))))

    # Row r, column c of the index grid is frame_start[r] + c.
    within_frame = np.arange(0, nw)
    frame_starts = np.arange(0, nf * inc, inc)
    grid = frame_starts[:, np.newaxis] + within_frame[np.newaxis, :]
    return padded[grid.astype(np.int32)]


def endpoint_detection(data, fs, num_of_channel=1):
    """Find active segments in a signal from short-time energy and variance.

    The signal is scaled by its peak absolute value and split (all samples
    but the last) into frames of ``round(fs * 0.032)`` samples with a hop of
    ``round(fs * 0.016)`` samples. Every frame except the final one gets a
    mean-square energy and a variance. A segment starts at the first frame
    whose energy and variance both exceed their thresholds and ends at the
    first frame from there on where both fall below their thresholds. A start
    without a matching end is discarded. Segments whose gap (next start minus
    previous end) is 30 frames or fewer are merged.

    Changes from the earlier version:
      * Merging is done on (start, end) pairs. Previously, three or more
        consecutive close segments left one start but two ends behind, so the
        returned lists had different lengths.
      * An empty or all-zero signal returns ``([], [])`` instead of raising
        or dividing by zero.
      * The search stops as soon as no further start can be found instead of
        rescanning the tail once per remaining frame.

    Args:
        data: 1-D sequence of samples.
        fs: Sampling rate; frame length and hop are derived from it.
        num_of_channel: Accepted for backward compatibility. Only a single
            channel is processed, so the value does not affect the result.

    Returns:
        Tuple ``(start_list, end_list)`` of equal-length lists holding frame
        indices (not sample indices) in ascending order.
    """
    data = np.asarray(data, dtype=float)
    if data.size == 0:
        return [], []
    peak = np.max(np.abs(data))
    if peak == 0:
        # A flat-zero signal contains no activity and cannot be normalised.
        return [], []
    data = data / peak

    frame_len = round(fs * 0.032)  # window length in samples
    frame_inc = round(fs * 0.016)  # hop between windows in samples
    energy = 0.02                  # upper bound for the energy threshold
    sum_s = 0.3                    # upper bound for the variance thresholds

    # Frame everything but the last sample, as the earlier version did.
    frames = enframe(data[:len(data) - 1], frame_len, frame_inc)
    nf = frames.shape[0]
    last = nf - 1  # the final frame is never evaluated; its stats stay 0

    energy_frame = np.zeros((nf,))
    sum_s_frame = np.zeros((nf,))
    for i in range(last):
        frame = frames[i, :]
        centered = frame - np.mean(frame)
        energy_frame[i] = np.sum(frame * frame) / frame_len
        sum_s_frame[i] = np.sum(centered * centered) / frame_len

    # Adapt the thresholds to the signal, capped by the constants above.
    energy = min(energy, 0.02 * max(energy_frame))
    sum_s1 = min(sum_s, 0.01 * max(sum_s_frame))   # variance level that ends a segment
    sum_s2 = min(sum_s, 0.028 * max(sum_s_frame))  # variance level that starts a segment

    segments = []
    search_from = 0
    while search_from < last:
        start = next(
            (k for k in range(search_from, last)
             if energy_frame[k] > energy and sum_s_frame[k] > sum_s2),
            None,
        )
        if start is None:
            break  # no more activity in the remaining frames
        end = next(
            (m for m in range(start, last)
             if energy_frame[m] < energy and sum_s_frame[m] < sum_s1),
            None,
        )
        if end is None:
            break  # activity runs to the end of the data: drop the open segment
        segments.append([start, end])
        search_from = end + 1

    # Merge neighbours separated by a gap of 30 frames or fewer. The gap is
    # measured against the end of the segment as originally detected, which
    # is what the running merged end always equals at this point.
    merged = []
    for start, end in segments:
        if merged and start - merged[-1][1] <= 30:
            merged[-1][1] = end
        else:
            merged.append([start, end])

    start_list = [start for start, _ in merged]
    end_list = [end for _, end in merged]
    return start_list, end_list


def get_feature(signal, fs):
    """Detect active segments in a raw recording and extract features from each.

    Pipeline: convert the raw readings with ``calculate_ems``, locate active
    segments with ``endpoint_detection``, cut each segment out of the
    converted signal, square it with ``ca_ms_energy``, band-pass it with
    ``filter_signal`` and pass the result to
    ``feature_extraction(...).get_4_feature()``.

    Changes from the earlier version:
      * The slice start is clamped at 0. A segment starting in frame 0 used to
        produce a negative start index, which Python interprets as counting
        from the end of the array and which gave an empty or wrong segment.
      * Success is decided by whether any non-empty segment exists, not by
        the length of the last segment only; empty segments are skipped.
      * The normalised and smoothed copies that were computed but never used
        are no longer produced.

    Args:
        signal: ndarray of raw readings.
        fs: Sampling rate; also used to convert frame indices to samples.

    Returns:
        List with one ``get_4_feature()`` result per detected segment, or
        ``None`` when no segment is found.
    """
    ems_signal = calculate_ems(signal)
    start_list, end_list = endpoint_detection(ems_signal, fs)
    start_list = sorted(start_list)
    end_list = sorted(end_list)

    # Cut the signal using the detected start/end frame indices.
    active_frame = []
    for start_point, end_point in zip(start_list, end_list):
        first = max(0, round((start_point - 1) * 0.016 * fs))
        last = max(0, round((end_point - 2) * 0.016 * fs + fs * 0.032))
        segment = ems_signal[first:last]
        if len(segment) > 0:
            active_frame.append(segment)

    if not active_frame:
        print('EMS 没有检测到的活动端点')
        return None

    print('检测到的活动端点为:')
    print(start_list)
    print(end_list)

    ems_features = []
    for segment in active_frame:
        filtered = filter_signal(ca_ms_energy(segment, fs))
        obj_feature = feature_extraction(filtered)
        ems_features.append(obj_feature.get_4_feature())
    return ems_features


if __name__ == '__main__':
    # Demo run: load the sample recording, treat it as sampled at 1000 Hz and
    # print whatever get_feature returns (a feature list, or None when no
    # activity is detected).
    sample_rate = 1000
    recording = np.loadtxt('../data/ems_data.txt')
    print(get_feature(recording, sample_rate))

# if __name__ == '__main__':
#     fs = 1000
#     data = np.loadtxt('../data/ems_data.txt')
#     data = calculate_ems(data)
#     start_list, end_list = endpoint_detection(data, fs)
#     start_list = sorted(start_list)
#     end_list = sorted(end_list)
#     active_frame = []
#     # 用解算得到的起点、终点来截取数据
#     for start_point, end_point in zip(start_list, end_list):
#         frame = data[round((start_point-1)*0.016*fs): round((end_point-2)*0.016*fs+fs*0.032)]
#         active_frame.append(frame)
#     filter_active_frame = []
#     normalization_frame = []
#     smooth_frame = []
#     power_frame = []
#     for i in range(len(active_frame)):
#         # filter_active_frame.append(filter_signal(active_frame[i]))
#         # normalization_frame.append(normalization(filter_active_frame[i]))
#         # smooth_frame.append(smooth_signal(normalization_frame[i]))
#         power_frame.append(ca_ms_energy(active_frame[i], fs))
#         filter_active_frame.append(filter_signal(power_frame[i]))
#         normalization_frame.append(normalization(filter_active_frame[i]))
#         smooth_frame.append(smooth_signal(normalization_frame[i]))
#         # plt.figure(i+1)
#         # plt.plot(smooth_frame[i])
#     for i in range(len(filter_active_frame)):
#         obj_feature = feature_extraction(filter_active_frame[i])
#         print('frame ' + str(i+1) + ':')
#         print('MAV, VAR, ZC, RMS, WA, MNP, MNF, MF')
#         print(obj_feature.get_feature())
#     # plt.show()
#

